import os
import pandas as pd
import argparse
from utils import *
from model import GNNs
from training import train_model
from earlystopping import stopping_args
from propagation import *
from load_data import *

# Input feature dimensionality of the citation datasets loaded via load_new_data_tkipf.
_TKIPF_FEATURE_DIMS = {'cora': 1433, 'citeseer': 3703, 'pubmed': 500}
# Datasets loaded with explicit train/val/test ratios instead of a per-class label count.
_RATIO_SPLIT_DATASETS = ('chameleon', 'squirrel', 'cornell', 'texas', 'wisconsin', 'film')
# Propagation schemes selectable with --model.
_MODELS = ('APGNN', 'PPNP', 'APPNP', 'GNN-LF', 'GNN-HF')


def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``). Such flags could therefore
    never be switched off from the command line. This parser accepts the usual
    true/false spellings instead.

    Args:
        value: The raw argument string, or an already-parsed ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def format_summary(stopping_acc, valtest_acc, runtime, runtime_perepoch, test=True):
    """Render the aggregated results of all runs as a multi-line report.

    Each statistic argument is a mapping with ``'mean'`` and ``'uncertainty'``
    keys (the shape this script reads from ``calc_uncertainty``). Accuracies
    are scaled to percent and the per-epoch runtime to milliseconds.

    Args:
        stopping_acc: Accuracy at the early-stopping checkpoint.
        valtest_acc: Accuracy on the test (or validation) split.
        runtime: Total runtime in seconds.
        runtime_perepoch: Runtime per epoch in seconds.
        test: Label the second line "Test" when true, otherwise "Validation".

    Returns:
        The report text, terminated by a single newline.
    """
    return (
        "Early stopping: Accuracy: {:.2f} ± {:.2f}%\n"
        "{}: ACC: {:.2f} ± {:.2f}%\n"
        "Runtime: {:.3f} ± {:.3f} sec, per epoch: {:.2f} ± {:.2f}ms\n"
    ).format(
        stopping_acc['mean'] * 100,
        stopping_acc['uncertainty'] * 100,
        'Test' if test else 'Validation',
        valtest_acc['mean'] * 100,
        valtest_acc['uncertainty'] * 100,
        runtime['mean'],
        runtime['uncertainty'],
        runtime_perepoch['mean'] * 1e3,
        runtime_perepoch['uncertainty'] * 1e3,
    )


def _load_graph(args):
    """Load ``args.dataset`` with its matching loader; return ``(graph, idx_np)``.

    Raises:
        ValueError: If the dataset name is not recognised.
    """
    if args.dataset == 'acm':
        return load_new_data_acm(args.labelrate)
    if args.dataset == 'wiki':
        return load_new_data_wiki(args.labelrate)
    if args.dataset == 'ms':
        return load_new_data_ms(args.labelrate)
    if args.dataset in _RATIO_SPLIT_DATASETS:
        return load_new_data(args.dataset, args.train_labelrate, args.val_labelrate,
                             args.test_labelrate, args.random_seed, args.random_split)
    if args.dataset in _TKIPF_FEATURE_DIMS:
        return load_new_data_tkipf(args.dataset, _TKIPF_FEATURE_DIMS[args.dataset], args.labelrate)
    raise ValueError(f"Unknown dataset: {args.dataset!r}")


def _build_propagation(args, adj_matrix):
    """Instantiate the propagation module selected by ``args.model``.

    Raises:
        ValueError: If ``args.model`` is not one of ``_MODELS``.
    """
    if args.model == 'APGNN':
        return APGNN(adj_matrix, niter=args.niter, npow=args.npow, drop_prob=None,
                     use_residual=args.use_residual, alpha=args.alpha)
    if args.model == 'GNN-LF':
        # NOTE(review): LF takes --beta as its 'mu' argument; confirm this is intended.
        return LFPowerIteration(adj_matrix, alpha=args.alpha, mu=args.beta, niter=args.niter)
    if args.model == 'GNN-HF':
        return HFPowerIteration(adj_matrix, alpha=args.alpha, beta=args.beta, niter=args.niter)
    if args.model == 'APPNP':
        return PPRPowerIteration(adj_matrix, alpha=args.alpha, npow=1, niter=args.niter)
    if args.model == 'PPNP':
        return PPRExact(adj_matrix, alpha=args.alpha)
    raise ValueError(f"Unknown model: {args.model!r}")


if __name__ == "__main__":
    parse = argparse.ArgumentParser()
    # some system settings
    parse.add_argument("--print_interval", help="The print interval", type=int, default=50)
    parse.add_argument("--device", help="GPU device", type=str, default="0")
    parse.add_argument("--runs", help="runs", type=int, default=10)

    # dataset settings and experiment settings
    parse.add_argument("-d", "--dataset", help="dataset", type=str, default='cora')
    parse.add_argument("-l", "--labelrate", help="labeled data for train per class", type=int, default=20)
    parse.add_argument("--train_labelrate", help="labeled rate of training set", type=float, default=0.48)
    parse.add_argument("--val_labelrate", help="labeled data of validation set", type=float, default=0.32)
    parse.add_argument("--test_labelrate", help="labeled data of testing set", type=float, default=0.2)
    parse.add_argument("--max_epochs", help="The number of max epochs", type=int, default=500)
    parse.add_argument("--patience", help="The number of patience for early stopping", type=int, default=50)
    # str2bool instead of type=bool, so that "False" actually parses as False.
    parse.add_argument("--random_split", help="use the random split", type=str2bool, default=True)
    parse.add_argument("--random_seed", help="random seed", type=str2bool, default=False)

    # model settings
    parse.add_argument("--model", help="APGNN,PPNP,APPNP,GNN-LF,GNN-HF", type=str, default="APGNN",
                       choices=_MODELS)
    parse.add_argument("--niter", help="times for iteration", type=int, default=10)
    parse.add_argument("--npow", help="P-hop", type=int, default=1)
    parse.add_argument("--reg_lambda", help="L2 regularization for the parameters", type=float, default=0.005)
    parse.add_argument("--lr", help="learning rate", type=float, default=0.01)
    parse.add_argument("--dropout", help="dropout rate", type=float, default=0.8)
    parse.add_argument("--alpha", help="the para 'alpha' for APGNN", type=float, default=0.9)
    parse.add_argument("--beta", help="the para 'beta' for GNN-LF/HF", type=float, default=0.9)
    parse.add_argument("--use_residual", help="use residual", type=str2bool, default=False)

    args = parse.parse_args()
    if args.runs < 1:
        # With zero runs the results table would be empty and the summary lookups would fail.
        parse.error("--runs must be at least 1")
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # print the args
    print(args)
    filename = f"{args.model}_{args.dataset}_labelrate_{args.labelrate}.txt"
    # Close (and so flush) the handle before training. train_model is also given
    # this filename (presumably to append its own log), and a still-buffered
    # handle could otherwise write the args line after that output.
    with open(filename, 'a+') as f:
        f.write('\n' + str(args))

    # load the dataset
    graph, idx_np = _load_graph(args)

    device = 'cuda'
    test = True

    stopping_args['max_epochs'] = args.max_epochs
    stopping_args['patience'] = args.patience

    results = []
    # the number of runs for experiments
    for run in range(1, args.runs + 1):
        # Rebuilt every run so each run starts from a freshly constructed propagation module.
        propagation = _build_propagation(args, graph.adj_matrix)

        model_args = {
            # The construction for hidden layers. For example: [128, 64, 32] or [64, 64]
            'hiddenunits': [64],
            'drop_prob': args.dropout,
            'propagation': propagation,
        }

        print(f"Iteration {run} of {args.runs}")
        _, result = train_model(idx_np, args.dataset, GNNs, graph, model_args, args.lr, args.reg_lambda,
                                stopping_args, test, device, None, args.print_interval, filename)
        results.append({
            'stopping_accuracy': result['early_stopping']['accuracy'],
            'valtest_accuracy': result['valtest']['accuracy'],
            'runtime': result['runtime'],
            'runtime_perepoch': result['runtime_perepoch'],
        })

    result_df = pd.DataFrame(results)

    summary = format_summary(
        calc_uncertainty(result_df['stopping_accuracy']),
        calc_uncertainty(result_df['valtest_accuracy']),
        calc_uncertainty(result_df['runtime']),
        calc_uncertainty(result_df['runtime_perepoch']),
        test=test,
    )

    # Label the report with the model actually run (previously hard-coded "APGNN").
    print(f"{args.model}\n{summary}")
    with open(filename, 'a+') as f:
        f.write(f"\n{args.model}_\n{summary}\n")

